{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "+++\n",
    "title = \"Generalized Procrustes analysis\"\n",
    "menu = \"main\"\n",
    "weight = 6\n",
    "toc = true\n",
    "aliases = [\"gpa\"]\n",
    "+++"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resources\n",
    "\n",
    "🤷‍♂️"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User guide\n",
    "\n",
    "Generalized Procrustes analysis (GPA) is a shape analysis tool that aligns and scales a set of shapes to a common reference. Here, the term \"shape\" means an *ordered* sequence of points. GPA iterates three steps: 1) align each shape with a reference shape (usually the mean shape), 2) update the reference shape, and 3) repeat until convergence.\n",
    "\n",
    "Note that the final orientation of the aligned shapes may vary between runs, depending on the initialization.\n",
    "\n",
    "Here is an example aligning a few right triangles:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.633858Z",
     "iopub.status.busy": "2024-09-07T18:18:01.633426Z",
     "iopub.status.idle": "2024-09-07T18:18:01.668183Z",
     "shell.execute_reply": "2024-09-07T18:18:01.667693Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>shape</th>\n",
       "      <th>point</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     x    y  shape  point\n",
       "0  0.0  0.0      0      0\n",
       "1  0.0  2.0      0      1\n",
       "2  1.0  0.0      0      2"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "points = pd.DataFrame(\n",
    "    data=[\n",
    "        [0, 0, 0, 0],\n",
    "        [0, 2, 0, 1],\n",
    "        [1, 0, 0, 2],\n",
    "        [3, 2, 1, 0],\n",
    "        [1, 2, 1, 1],\n",
    "        [3, 3, 1, 2],\n",
    "        [0, 0, 2, 0],\n",
    "        [0, 4, 2, 1],\n",
    "        [2, 0, 2, 2],\n",
    "    ],\n",
    "    columns=['x', 'y', 'shape', 'point']\n",
    ").astype({'x': float, 'y': float})\n",
    "points.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.670859Z",
     "iopub.status.busy": "2024-09-07T18:18:01.670666Z",
     "iopub.status.idle": "2024-09-07T18:18:01.755157Z",
     "shell.execute_reply": "2024-09-07T18:18:01.754860Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-163523f3e6ec4bb485bf25d7a35cc5ec\"></div>\n",
       "<script type=\"text/javascript\">\n",
       "  var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
       "  (function(spec, embedOpt){\n",
       "    let outputDiv = document.currentScript.previousElementSibling;\n",
       "    if (outputDiv.id !== \"altair-viz-163523f3e6ec4bb485bf25d7a35cc5ec\") {\n",
       "      outputDiv = document.getElementById(\"altair-viz-163523f3e6ec4bb485bf25d7a35cc5ec\");\n",
       "    }\n",
       "    const paths = {\n",
       "      \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
       "      \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
       "      \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
       "      \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
       "    };\n",
       "\n",
       "    function maybeLoadScript(lib, version) {\n",
       "      var key = `${lib.replace(\"-\", \"\")}_version`;\n",
       "      return (VEGA_DEBUG[key] == version) ?\n",
       "        Promise.resolve(paths[lib]) :\n",
       "        new Promise(function(resolve, reject) {\n",
       "          var s = document.createElement('script');\n",
       "          document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "          s.async = true;\n",
       "          s.onload = () => {\n",
       "            VEGA_DEBUG[key] = version;\n",
       "            return resolve(paths[lib]);\n",
       "          };\n",
       "          s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
       "          s.src = paths[lib];\n",
       "        });\n",
       "    }\n",
       "\n",
       "    function showError(err) {\n",
       "      outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
       "      throw err;\n",
       "    }\n",
       "\n",
       "    function displayChart(vegaEmbed) {\n",
       "      vegaEmbed(outputDiv, spec, embedOpt)\n",
       "        .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
       "    }\n",
       "\n",
       "    if(typeof define === \"function\" && define.amd) {\n",
       "      requirejs.config({paths});\n",
       "      require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
       "    } else {\n",
       "      maybeLoadScript(\"vega\", \"5\")\n",
       "        .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
       "        .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
       "        .catch(showError)\n",
       "        .then(() => displayChart(vegaEmbed));\n",
       "    }\n",
       "  })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}}, \"data\": {\"name\": \"data-28b7252a2dc4b45238c1274f5dffb9c6\"}, \"mark\": {\"type\": \"line\", \"opacity\": 0.5}, \"encoding\": {\"color\": {\"field\": \"shape\", \"type\": \"nominal\"}, \"detail\": {\"field\": \"shape\", \"type\": \"quantitative\"}, \"x\": {\"field\": \"x\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"y\", \"type\": \"quantitative\"}}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-28b7252a2dc4b45238c1274f5dffb9c6\": [{\"x\": 0.0, \"y\": 0.0, \"shape\": 0, \"point\": 0}, {\"x\": 0.0, \"y\": 2.0, \"shape\": 0, \"point\": 1}, {\"x\": 1.0, \"y\": 0.0, \"shape\": 0, \"point\": 2}, {\"x\": 3.0, \"y\": 2.0, \"shape\": 1, \"point\": 0}, {\"x\": 1.0, \"y\": 2.0, \"shape\": 1, \"point\": 1}, {\"x\": 3.0, \"y\": 3.0, \"shape\": 1, \"point\": 2}, {\"x\": 0.0, \"y\": 0.0, \"shape\": 2, \"point\": 0}, {\"x\": 0.0, \"y\": 4.0, \"shape\": 2, \"point\": 1}, {\"x\": 2.0, \"y\": 0.0, \"shape\": 2, \"point\": 2}]}}, {\"mode\": \"vega-lite\"});\n",
       "</script>"
      ],
      "text/plain": [
       "alt.Chart(...)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import altair as alt\n",
    "\n",
    "alt.Chart(points).mark_line(opacity=0.5).encode(\n",
    "    x='x',\n",
    "    y='y',\n",
    "    detail='shape',\n",
    "    color='shape:N'\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataframe of points has to be converted to a 3D NumPy array of shape `(shapes, points, dims)`. There are many ways to do this. Here, we use xarray as a helper package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.756840Z",
     "iopub.status.busy": "2024-09-07T18:18:01.756743Z",
     "iopub.status.idle": "2024-09-07T18:18:01.807548Z",
     "shell.execute_reply": "2024-09-07T18:18:01.807313Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 3, 2)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
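    "# Index by (shape, point) so that x and y become variables on a 2D grid\n",
    "ds = points.set_index(['shape', 'point']).to_xarray()\n",
    "# Stack the x and y variables along a new 'xy' dimension\n",
    "da = ds.to_stacked_array('xy', ['shape', 'point'])\n",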
    "shapes = da.values\n",
    "shapes.shape"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This can also be done with a pandas `groupby` and NumPy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.809002Z",
     "iopub.status.busy": "2024-09-07T18:18:01.808906Z",
     "iopub.status.idle": "2024-09-07T18:18:01.818337Z",
     "shell.execute_reply": "2024-09-07T18:18:01.818121Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 3, 2)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "gb = points.groupby('shape')\n",
    "np.stack([gb.get_group(g)[['x', 'y']] for g in gb.groups]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.819667Z",
     "iopub.status.busy": "2024-09-07T18:18:01.819581Z",
     "iopub.status.idle": "2024-09-07T18:18:01.826748Z",
     "shell.execute_reply": "2024-09-07T18:18:01.826491Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0., 0.],\n",
       "        [0., 2.],\n",
       "        [1., 0.]],\n",
       "\n",
       "       [[3., 2.],\n",
       "        [1., 2.],\n",
       "        [3., 3.]],\n",
       "\n",
       "       [[0., 0.],\n",
       "        [0., 4.],\n",
       "        [2., 0.]]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shapes"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes can now be aligned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:01.828533Z",
     "iopub.status.busy": "2024-09-07T18:18:01.828396Z",
     "iopub.status.idle": "2024-09-07T18:18:02.157698Z",
     "shell.execute_reply": "2024-09-07T18:18:02.157289Z"
    }
   },
   "outputs": [],
   "source": [
    "import prince\n",
    "\n",
    "gpa = prince.GPA()\n",
    "aligned_shapes = gpa.fit_transform(shapes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then convert the 3D NumPy array back to a dataframe (using `xarray`) for plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-07T18:18:02.159473Z",
     "iopub.status.busy": "2024-09-07T18:18:02.159364Z",
     "iopub.status.idle": "2024-09-07T18:18:02.187045Z",
     "shell.execute_reply": "2024-09-07T18:18:02.186796Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div id=\"altair-viz-36f965231bd64457b172e6c0c8c0f7b3\"></div>\n",
       "<script type=\"text/javascript\">\n",
       "  var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
       "  (function(spec, embedOpt){\n",
       "    let outputDiv = document.currentScript.previousElementSibling;\n",
       "    if (outputDiv.id !== \"altair-viz-36f965231bd64457b172e6c0c8c0f7b3\") {\n",
       "      outputDiv = document.getElementById(\"altair-viz-36f965231bd64457b172e6c0c8c0f7b3\");\n",
       "    }\n",
       "    const paths = {\n",
       "      \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
       "      \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
       "      \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
       "      \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
       "    };\n",
       "\n",
       "    function maybeLoadScript(lib, version) {\n",
       "      var key = `${lib.replace(\"-\", \"\")}_version`;\n",
       "      return (VEGA_DEBUG[key] == version) ?\n",
       "        Promise.resolve(paths[lib]) :\n",
       "        new Promise(function(resolve, reject) {\n",
       "          var s = document.createElement('script');\n",
       "          document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "          s.async = true;\n",
       "          s.onload = () => {\n",
       "            VEGA_DEBUG[key] = version;\n",
       "            return resolve(paths[lib]);\n",
       "          };\n",
       "          s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
       "          s.src = paths[lib];\n",
       "        });\n",
       "    }\n",
       "\n",
       "    function showError(err) {\n",
       "      outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
       "      throw err;\n",
       "    }\n",
       "\n",
       "    function displayChart(vegaEmbed) {\n",
       "      vegaEmbed(outputDiv, spec, embedOpt)\n",
       "        .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
       "    }\n",
       "\n",
       "    if(typeof define === \"function\" && define.amd) {\n",
       "      requirejs.config({paths});\n",
       "      require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
       "    } else {\n",
       "      maybeLoadScript(\"vega\", \"5\")\n",
       "        .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
       "        .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
       "        .catch(showError)\n",
       "        .then(() => displayChart(vegaEmbed));\n",
       "    }\n",
       "  })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}}, \"data\": {\"name\": \"data-c02c635fccec02f1185693da22cf110f\"}, \"mark\": {\"type\": \"line\", \"opacity\": 0.5}, \"encoding\": {\"color\": {\"field\": \"shape\", \"type\": \"nominal\"}, \"detail\": {\"field\": \"shape\", \"type\": \"quantitative\"}, \"x\": {\"field\": \"x\", \"type\": \"quantitative\"}, \"y\": {\"field\": \"y\", \"type\": \"quantitative\"}}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.17.0.json\", \"datasets\": {\"data-c02c635fccec02f1185693da22cf110f\": [{\"shape\": 0, \"point\": 0, \"x\": -0.1825741858350555, \"y\": -0.36514837167011066}, {\"shape\": 0, \"point\": 1, \"x\": -0.1825741858350552, \"y\": 0.7302967433402214}, {\"shape\": 0, \"point\": 2, \"x\": 0.36514837167011066, \"y\": -0.3651483716701107}, {\"shape\": 1, \"point\": 0, \"x\": -0.1825741858350555, \"y\": -0.3651483716701106}, {\"shape\": 1, \"point\": 1, \"x\": -0.18257418583505522, \"y\": 0.7302967433402213}, {\"shape\": 1, \"point\": 2, \"x\": 0.36514837167011066, \"y\": -0.36514837167011077}, {\"shape\": 2, \"point\": 0, \"x\": -0.1825741858350555, \"y\": -0.36514837167011066}, {\"shape\": 2, \"point\": 1, \"x\": -0.1825741858350552, \"y\": 0.7302967433402214}, {\"shape\": 2, \"point\": 2, \"x\": 0.36514837167011066, \"y\": -0.3651483716701107}]}}, {\"mode\": \"vega-lite\"});\n",
       "</script>"
      ],
      "text/plain": [
       "alt.Chart(...)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "da.values = aligned_shapes\n",
    "aligned_points = da.to_unstacked_dataset('xy').to_dataframe().reset_index()\n",
    "\n",
    "alt.Chart(aligned_points).mark_line(opacity=0.5).encode(\n",
    "    x='x',\n",
    "    y='y',\n",
    "    detail='shape',\n",
    "    color='shape:N'\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The triangles all have the same shape, differing only by translation, rotation, and scale, so they are now perfectly aligned."
   ]
  }
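  ,
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iteration described in the user guide can be sketched in plain NumPy. This is a minimal illustration and not prince's implementation: the names `align_to_reference` and `gpa_sketch`, the use of the first shape as the initial reference, and the `max_iter` and `tol` defaults are all assumptions made for this example. The rotation step solves the orthogonal Procrustes problem with an SVD, which also permits reflections.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def align_to_reference(shape, reference):\n",
    "    # Orthogonal Procrustes: find the orthogonal matrix R that minimizes\n",
    "    # ||shape @ R - reference|| from the SVD of shape.T @ reference.\n",
    "    u, _, vt = np.linalg.svd(shape.T @ reference)\n",
    "    return shape @ (u @ vt)\n",
    "\n",
    "\n",
    "def gpa_sketch(shapes, max_iter=10, tol=1e-10):\n",
    "    # Remove translation (center each shape) and scale (unit Frobenius norm).\n",
    "    shapes = shapes - shapes.mean(axis=1, keepdims=True)\n",
    "    shapes = shapes / np.linalg.norm(shapes, axis=(1, 2), keepdims=True)\n",
    "    reference = shapes[0]  # assumed initialization: the first shape\n",
    "    for _ in range(max_iter):\n",
    "        # 1) Align every shape with the current reference.\n",
    "        shapes = np.stack([align_to_reference(s, reference) for s in shapes])\n",
    "        # 2) Update the reference to the normalized mean shape.\n",
    "        new_reference = shapes.mean(axis=0)\n",
    "        new_reference = new_reference / np.linalg.norm(new_reference)\n",
    "        # 3) Stop once the reference no longer moves.\n",
    "        if np.linalg.norm(new_reference - reference) < tol:\n",
    "            break\n",
    "        reference = new_reference\n",
    "    return shapes\n",
    "\n",
    "\n",
    "# The same three right triangles as in the example above.\n",
    "triangles = np.array([\n",
    "    [[0, 0], [0, 2], [1, 0]],\n",
    "    [[3, 2], [1, 2], [3, 3]],\n",
    "    [[0, 0], [0, 4], [2, 0]],\n",
    "], dtype=float)\n",
    "aligned = gpa_sketch(triangles)\n",
    "print(np.allclose(aligned[0], aligned[1]), np.allclose(aligned[0], aligned[2]))\n",
    "```\n",
    "\n",
    "Starting from a different reference, such as another shape or a rotated copy, would produce the same relative alignment in a different orientation, which is why the final orientation depends on the initialization."
   ]
  }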
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "441c2ec70d9faeb70e7723f55150c6260f4a26a9c828b90915d3399002e14f43"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
